{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-class classification"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Classification is about predicting an outcome from a fixed list of classes. The prediction is a probability distribution that assigns a probability to each possible outcome.\n",
    "\n",
    "A labeled classification sample is made up of a set of features and a class. In multiclass classification, the class is usually a string or a number. We'll use the image segments dataset as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:17.633774Z",
     "iopub.status.busy": "2023-12-04T17:47:17.633578Z",
     "iopub.status.idle": "2023-12-04T17:47:18.059839Z",
     "shell.execute_reply": "2023-12-04T17:47:18.059524Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Image segments classification.\n",
       "\n",
       "This dataset contains features that describe image segments into 7 classes: brickface, sky,\n",
       "foliage, cement, window, path, and grass.\n",
       "\n",
       "    Name  ImageSegments                                                     \n",
       "    Task  Multi-class classification                                        \n",
       " Samples  2,310                                                             \n",
       "Features  18                                                                \n",
       " Classes  7                                                                 \n",
       "  Sparse  False                                                             \n",
       "    Path  /Users/max/projects/online-ml/river/river/datasets/segment.csv.zip"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import datasets\n",
    "\n",
    "dataset = datasets.ImageSegments()\n",
    "dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a streaming dataset: looping over it yields one `(x, y)` pair at a time, where `x` holds the features and `y` is the class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.062123Z",
     "iopub.status.busy": "2023-12-04T17:47:18.061924Z",
     "iopub.status.idle": "2023-12-04T17:47:18.085783Z",
     "shell.execute_reply": "2023-12-04T17:47:18.085481Z"
    }
   },
   "outputs": [],
   "source": [
    "for x, y in dataset:\n",
    "    pass"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the first sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.087624Z",
     "iopub.status.busy": "2023-12-04T17:47:18.087509Z",
     "iopub.status.idle": "2023-12-04T17:47:18.098210Z",
     "shell.execute_reply": "2023-12-04T17:47:18.097952Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'region-centroid-col': 218,\n",
       " 'region-centroid-row': 178,\n",
       " 'short-line-density-5': 0.11111111,\n",
       " 'short-line-density-2': 0.0,\n",
       " 'vedge-mean': 0.8333326999999999,\n",
       " 'vegde-sd': 0.54772234,\n",
       " 'hedge-mean': 1.1111094,\n",
       " 'hedge-sd': 0.5443307,\n",
       " 'intensity-mean': 59.629630000000006,\n",
       " 'rawred-mean': 52.44444300000001,\n",
       " 'rawblue-mean': 75.22222,\n",
       " 'rawgreen-mean': 51.22222,\n",
       " 'exred-mean': -21.555555,\n",
       " 'exblue-mean': 46.77778,\n",
       " 'exgreen-mean': -25.222220999999998,\n",
       " 'value-mean': 75.22222,\n",
       " 'saturation-mean': 0.31899637,\n",
       " 'hue-mean': -2.0405545}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x, y = next(iter(dataset))\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.099585Z",
     "iopub.status.busy": "2023-12-04T17:47:18.099501Z",
     "iopub.status.idle": "2023-12-04T17:47:18.108487Z",
     "shell.execute_reply": "2023-12-04T17:47:18.108239Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'path'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The goal of a multiclass classifier is to learn to predict a class `y` from a set of features `x`. We'll attempt to do this with a decision tree, here a `HoeffdingTreeClassifier`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.109928Z",
     "iopub.status.busy": "2023-12-04T17:47:18.109851Z",
     "iopub.status.idle": "2023-12-04T17:47:18.140712Z",
     "shell.execute_reply": "2023-12-04T17:47:18.140335Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import tree\n",
    "\n",
    "model = tree.HoeffdingTreeClassifier()\n",
    "model.predict_proba_one(x)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output dictionary is empty because the model hasn't seen any data yet, so it doesn't know which classes exist. A binary classifier would output a probability of 50% for each of `True` and `False`, because its two classes are implicit. In multiclass classification, the set of classes isn't known in advance.\n",
    "\n",
    "Likewise, the `predict_one` method initially returns `None` because the model hasn't seen any labeled data yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.142820Z",
     "iopub.status.busy": "2023-12-04T17:47:18.142689Z",
     "iopub.status.idle": "2023-12-04T17:47:18.152525Z",
     "shell.execute_reply": "2023-12-04T17:47:18.152285Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(model.predict_one(x))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we update the model with this sample and try again, it assigns a probability of 100% to the `'path'` class, because that's the only class the model has seen so far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.153876Z",
     "iopub.status.busy": "2023-12-04T17:47:18.153798Z",
     "iopub.status.idle": "2023-12-04T17:47:18.165027Z",
     "shell.execute_reply": "2023-12-04T17:47:18.164772Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'path': 1.0}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.learn_one(x, y)\n",
    "model.predict_proba_one(x)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a strength of online classifiers: they're able to deal with new classes appearing in the data stream.\n",
    "\n",
    "Typically, an online model makes a prediction and then learns once the ground truth is revealed. Comparing the prediction with the ground truth measures the model's correctness. If you have a dataset available, you can loop over it and, for each sample, make a prediction, update the model, and compare the prediction with the ground truth. This is called progressive validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.166512Z",
     "iopub.status.busy": "2023-12-04T17:47:18.166430Z",
     "iopub.status.idle": "2023-12-04T17:47:18.610839Z",
     "shell.execute_reply": "2023-12-04T17:47:18.610498Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "            Precision   Recall   F1       Support  \n",
       "                                                   \n",
       "brickface      77.13%   84.85%   80.81%       330  \n",
       "   cement      78.92%   83.94%   81.35%       330  \n",
       "  foliage      65.69%   20.30%   31.02%       330  \n",
       "    grass     100.00%   96.97%   98.46%       330  \n",
       "     path      90.63%   91.19%   90.91%       329  \n",
       "      sky      99.08%   98.18%   98.63%       330  \n",
       "   window      43.50%   67.88%   53.02%       330  \n",
       "                                                   \n",
       "    Macro      79.28%   77.62%   76.31%            \n",
       "    Micro      77.61%   77.61%   77.61%            \n",
       " Weighted      79.27%   77.61%   76.31%            \n",
       "\n",
       "                  77.61% accuracy                  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import metrics\n",
    "\n",
    "model = tree.HoeffdingTreeClassifier()\n",
    "\n",
    "metric = metrics.ClassificationReport()\n",
    "\n",
    "for x, y in dataset:\n",
    "    # Predict before learning, so each sample is scored while still unseen\n",
    "    y_pred = model.predict_one(x)\n",
    "    model.learn_one(x, y)\n",
    "    # predict_one returns None until the model has seen labeled data\n",
    "    if y_pred is not None:\n",
    "        metric.update(y, y_pred)\n",
    "\n",
    "metric"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a common way to evaluate an online model. In fact, there is a dedicated `evaluate.progressive_val_score` function that does this for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:18.613061Z",
     "iopub.status.busy": "2023-12-04T17:47:18.612927Z",
     "iopub.status.idle": "2023-12-04T17:47:19.085072Z",
     "shell.execute_reply": "2023-12-04T17:47:19.084781Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "            Precision   Recall   F1       Support  \n",
       "                                                   \n",
       "brickface      77.13%   84.85%   80.81%       330  \n",
       "   cement      78.92%   83.94%   81.35%       330  \n",
       "  foliage      65.69%   20.30%   31.02%       330  \n",
       "    grass     100.00%   96.97%   98.46%       330  \n",
       "     path      90.63%   91.19%   90.91%       329  \n",
       "      sky      99.08%   98.18%   98.63%       330  \n",
       "   window      43.50%   67.88%   53.02%       330  \n",
       "                                                   \n",
       "    Macro      79.28%   77.62%   76.31%            \n",
       "    Micro      77.61%   77.61%   77.61%            \n",
       " Weighted      79.27%   77.61%   76.31%            \n",
       "\n",
       "                  77.61% accuracy                  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import evaluate\n",
    "\n",
    "model = tree.HoeffdingTreeClassifier()\n",
    "metric = metrics.ClassificationReport()\n",
    "\n",
    "evaluate.progressive_val_score(dataset, model, metric)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e6e87bad9c8c768904c061eafcb4f6739260ff8bb57f302c215ab258ded773dc"
  },
  "kernelspec": {
   "display_name": "Python 3.9.12 ('river')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
